import torch
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
# Helper functions to visualize
def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=None):
    """
    Plot a heatmap of a tensor using seaborn
    """
    sns.heatmap(tensor.cpu().numpy(), ax=ax, vmin=vmin, vmax=vmax, cmap=cmap,
                annot=True, fmt=".2f", cbar=False)
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
def plot_quantization_errors(original_tensor, quantized_tensor,
                             dequantized_tensor, dtype=torch.int8, n_bits=8):
    """
    Plot 4 matrices: the original tensor, the quantized tensor,
    the de-quantized tensor and the error tensor.
    """
    # Get a figure of 4 plots
    fig, axes = plt.subplots(1, 4, figsize=(15, 4))

    # Plot the first matrix
    plot_matrix(original_tensor, axes[0], 'Original Tensor', cmap=ListedColormap(['white']))

    # Get the quantization range and plot the quantized tensor
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    plot_matrix(quantized_tensor, axes[1], f'{n_bits}-bit Linear Quantized Tensor',
                vmin=q_min, vmax=q_max, cmap='coolwarm')

    # Plot the de-quantized tensor
    plot_matrix(dequantized_tensor, axes[2], 'Dequantized Tensor', cmap='coolwarm')

    # Get the quantization errors
    q_error_tensor = abs(original_tensor - dequantized_tensor)
    plot_matrix(q_error_tensor, axes[3], 'Quantization Error Tensor', cmap=ListedColormap(['white']))

    fig.tight_layout()
    plt.show()
def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=torch.int8):
    scaled_and_shifted_tensor = tensor / scale + zero_point
    rounded_tensor = torch.round(scaled_and_shifted_tensor)

    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max

    q_tensor = rounded_tensor.clamp(q_min, q_max).to(dtype)
    return q_tensor

test_tensor = torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5, -184],
     [0, 684.6, 245.5]])

scale = 3.5
zero_point = -70

quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale, zero_point)

def linear_dequantization(quantized_tensor, scale, zero_point):
    return scale * (quantized_tensor.float() - zero_point)

dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)

plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
Quantization in Depth
This is completely based on the Quantization in Depth course.
For the code part, you can check out this link.
Quantize and De-quantize a tensor
- Advantages of Quantization
- Smaller model
- Speed gains
- Memory bandwidth
- Faster operations
- GEMM: General Matrix Multiply (matrix-to-matrix multiplication)
- GEMV: General Matrix Vector Multiplication (matrix-to-vector multiplication)
- Challenges of Quantization
- Quantization error
- Retraining (Quantization Aware Training)
- Limited Hardware support
- Calibration dataset needed
- packing/unpacking
- Getting q:
  - r = s(q - z), so q = int(round(r/s + z)) (a quick numeric check follows below)
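As a quick numeric sanity check of these two formulas (a tiny example of my own, reusing the scale and zero-point from the code above):

# Quantize one value with q = int(round(r/s + z)), then map it back with r = s(q - z)
r = 191.6                        # one entry of test_tensor
s, z = 3.5, -70                  # scale and zero-point used earlier
q = int(round(r / s + z))        # -> -15
r_hat = s * (q - z)              # -> 192.5
print(abs(r - r_hat))            # quantization error of roughly 0.9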
Get the Scale and Zero-Point
- s = (r_max - r_min) / (q_max - q_min), i.e. the current tensor range divided by the datatype range
- z = int(round(q_min - r_min/s))
- z and the quantized tensor are of the same type
- z is an integer because it represents zero (in the original 'r' range) with an integer in the quantized 'q' range
- If z goes out of range:
  - z < q_min: set z = q_min
  - z > q_max: set z = q_max
import torch

def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()

    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = q_min - (r_min / scale)

    # clip the zero_point to fall in [quantized_min, quantized_max]
    if zero_point < q_min:
        zero_point = q_min
    elif zero_point > q_max:
        zero_point = q_max
    else:
        # round and cast to int
        zero_point = int(round(zero_point))

    return scale, zero_point
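A short usage sketch of my own, tying this helper back to the quantize/dequantize functions defined earlier (the names new_scale and new_zero_point are mine):

# Derive scale and zero-point from the tensor itself, then quantize with them
new_scale, new_zero_point = get_q_scale_and_zero_point(test_tensor)

quantized_tensor = linear_q_with_scale_and_zero_point(
    test_tensor, new_scale, new_zero_point)
dequantized_tensor = linear_dequantization(
    quantized_tensor, new_scale, new_zero_point)

print((test_tensor - dequantized_tensor).abs().mean())  # average quantization error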
Symmetric vs Asymmetric Mode
- Asymmetric mode:
  - map [r_min, r_max] to [q_min, q_max]
  - This is what we have implemented above
- Symmetric mode:
  - map [-r_max, r_max] to [-q_max, q_max]
  - where r_max = max(|tensor|)
Hence, we can simplify the equations to:
- q = int(round(r/s))
- s = r_max/q_max
import torch

def get_q_scale_symmetric(tensor, dtype=torch.int8):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max

    # return the scale
    return r_max / q_max

def linear_q_symmetric(tensor, dtype=torch.int8):
    scale = get_q_scale_symmetric(tensor)

    quantized_tensor = linear_q_with_scale_and_zero_point(tensor,
                                                          scale=scale,
                                                          # in symmetric quantization the zero point is 0
                                                          zero_point=0,
                                                          dtype=dtype)

    return quantized_tensor, scale
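A short usage sketch (my own), assuming test_tensor from earlier is still in scope; in symmetric mode we dequantize with a zero-point of 0:

quantized_tensor, scale = linear_q_symmetric(test_tensor)
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)

print((test_tensor - dequantized_tensor).abs().mean())  # average quantization error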
Trade-off
Utilization of the quantized range:
- With asymmetric quantization, the quantized range is fully utilized.
- With symmetric mode, if the float range is biased towards one side, part of the quantized range is dedicated to values we will never see (e.g. the output of ReLU, which is always positive). See the small illustration below.
Simplicity:
- Symmetric mode is much simpler than asymmetric mode.
Memory: we don't have to store a zero-point for symmetric quantization.
We use symmetric quantization for 8-bit, but as we go to lower bit widths such as 2 or 4 bits, we use asymmetric quantization.
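A minimal sketch of the range-utilization point, using a made-up all-positive tensor as a stand-in for a ReLU output (the helpers above are assumed to be in scope):

relu_out = torch.rand(4, 4) * 10                  # all-positive values, like a ReLU output

s_asym, z_asym = get_q_scale_and_zero_point(relu_out)
q_asym = linear_q_with_scale_and_zero_point(relu_out, s_asym, z_asym)

q_sym, s_sym = linear_q_symmetric(relu_out)

print(q_asym.min().item(), q_asym.max().item())   # spans roughly -128 .. 127
print(q_sym.min().item(), q_sym.max().item())     # stays in 0 .. 127, half the range unused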
Finer Granularity for more Precision
- Different granularities
- per tensor
- per channel (along an axis)
- per group (group n elements together)
- The more granular the quantization is, the more precise it will be.
Per Channel Quantization
- We usually use per-channel quantization in int8.
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):
    output_dim = r_tensor.shape[dim]
    # store the scales
    scale = torch.zeros(output_dim)

    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim, index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # reshape the scale
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)

    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)

    return quantized_tensor, scale
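A short comparison of my own, reusing test_tensor (dim=0 means one scale per row); per-channel should typically give a smaller error because each row gets its own scale:

q_pt, s_pt = linear_q_symmetric(test_tensor)
deq_pt = linear_dequantization(q_pt, s_pt, 0)

q_pc, s_pc = linear_q_symmetric_per_channel(test_tensor, dim=0)
deq_pc = linear_dequantization(q_pc, s_pc, 0)

print((test_tensor - deq_pt).abs().mean())   # per-tensor error
print((test_tensor - deq_pc).abs().mean())   # per-channel error (usually smaller)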
Per Group Quantization
Group n elements together (e.g. 32, 64, 128) and quantize each group with its own scale.
Per-group quantization can require more memory for the scales:
- Say we quantize a tensor to 4-bit with group_size=32, symmetric mode (z=0), and store the scales in FP16.
- We are then effectively quantizing the tensor in 4.5 bits, since we have:
  - 4 bits (each element is stored in 4 bits)
  - 16/32 = 0.5 bit (one 16-bit scale for every 32 elements)
def linear_q_symmetric_per_group(tensor, group_size, dtype=torch.int8):
    t_shape = tensor.shape
    assert t_shape[1] % group_size == 0
    assert tensor.dim() == 2

    tensor = tensor.view(-1, group_size)

    quantized_tensor, scale = linear_q_symmetric_per_channel(
        tensor, dim=0, dtype=dtype)

    quantized_tensor = quantized_tensor.view(t_shape)

    return quantized_tensor, scale

def linear_dequantization_per_group(quantized_tensor, scale, group_size):
    q_shape = quantized_tensor.shape
    quantized_tensor = quantized_tensor.view(-1, group_size)

    dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)

    dequantized_tensor = dequantized_tensor.view(q_shape)

    return dequantized_tensor
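A usage sketch with a made-up 6x128 tensor and group_size=32 (both arbitrary choices of mine):

test_matrix = torch.rand((6, 128))
group_size = 32

q_group, s_group = linear_q_symmetric_per_group(test_matrix, group_size=group_size)
deq_group = linear_dequantization_per_group(q_group, s_group, group_size=group_size)

print((test_matrix - deq_group).abs().mean())  # small reconstruction error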
Quantizing Weights and Activations for Inference
- Depending on what we quantize, the storage and the computation are not the same.
- W8A32
  - Weights are quantized to 8-bit but activations are not, so computation is done in floating point (FP32, FP16, BF16).
  - We need to dequantize the weights (cast them back to float) to perform the floating-point computation. See the sketch after this list.
- W8A8
  - Both weights and activations are quantized.
  - Computation is integer based, but this is not supported by all hardware.
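To make the W8A32 case concrete, here is a minimal sketch of my own (the names w8_a32_forward, int8_w, w_scale are hypothetical, not from the course): the int8 weights are dequantized back to float32, and the matmul itself runs in floating point.

def w8_a32_forward(int8_weight, scale, input_fp32):
    # dequantize the weights, then do a plain floating-point matmul
    dequantized_weight = int8_weight.float() * scale
    return input_fp32 @ dequantized_weight.t()

int8_w = torch.randint(-128, 127, (8, 4), dtype=torch.int8)
w_scale = torch.rand(8, 1)          # one scale per output row
x = torch.randn(2, 4)               # float32 activations
print(w8_a32_forward(int8_w, w_scale, x).shape)  # torch.Size([2, 8])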
Custom Build an 8-Bit Quantizer
#W8A16LinearLayer
def w8_a16_forward(weight, input, scales, bias=None):
= weight.to(input.dtype)
casted_weights = F.linear(input, casted_weights) * scales
output
if bias is not None:
= output + bias
output
return output
import torch
import torch.nn as nn
import torch.nn.functional as F

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features,
                 bias=True, dtype=torch.float32):
        super().__init__()

        self.int8_weights = nn.Parameter(torch.Tensor([0, 1]).to(dtype=torch.int8))

try:
    W8A16LinearLayer(1, 1)
except Exception as error:
    print("\033[91m", type(error).__name__, ": ", error, "\033[0m")
RuntimeError : Only Tensors of floating point and complex dtype can require gradients
When we create an nn.Parameter, PyTorch expects to be able to compute gradients on it.
The issue is that PyTorch cannot compute gradients on INT8 tensors.
So the code snippet above raises an error saying that only tensors of floating point and complex dtype can require gradients.
The right approach to store INT8 weights is therefore not to save the attribute as an nn.Parameter, but to call the register_buffer method instead.
This way we store a buffer rather than a parameter, meaning we don't need to compute gradients on the tensor.
You can initialize a buffer with whatever dtype you want.
import torch
import torch.nn as nn
import torch.nn.functional as F

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features,
                 bias=True, dtype=torch.float32):
        super().__init__()

        self.register_buffer(
            "int8_weights",
            torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
        )

        # We are interested in inference only, so buffers are enough
        self.register_buffer("scales",
                             torch.randn((out_features), dtype=dtype))

        if bias:
            self.register_buffer("bias",
                                 torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

    def forward(self, input):
        return w8_a16_forward(self.int8_weights,
                              input, self.scales, self.bias)
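A quick smoke test of my own for the layer above, using a half-precision dummy input:

layer = W8A16LinearLayer(16, 32, dtype=torch.bfloat16)
dummy_hidden_states = torch.randn(1, 6, 16, dtype=torch.bfloat16)

output = layer(dummy_hidden_states)
print(output.shape, output.dtype)  # torch.Size([1, 6, 32]) torch.bfloat16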
Quantize a Base Model
import torch
import torch.nn as nn
import torch.nn.functional as F

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features,
                 bias=True, dtype=torch.float32):
        super().__init__()

        self.register_buffer(
            "int8_weights",
            torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
        )

        self.register_buffer("scales",
                             torch.randn((out_features), dtype=dtype))

        if bias:
            self.register_buffer("bias",
                                 torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        int8_weights = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales

    def forward(self, input):
        return w8_a16_forward(self.int8_weights,
                              input, self.scales, self.bias)

module = W8A16LinearLayer(4, 8)
print("Weights before:\n", module.int8_weights)

random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)

module.quantize(random_matrix)
print("Weights After:\n", module.int8_weights)
print("Average quantization error:-", (random_matrix - module.int8_weights
                                       * module.scales.unsqueeze(1)).abs().mean())
Weights before:
tensor([[ -55, 55, 41, 27],
[ 1, -81, 34, 126],
[ -12, 35, -104, 100],
[ 63, -58, -109, 10],
[ 116, -24, -1, -124],
[ -48, 4, -9, 23],
[ -63, -48, -18, -26],
[ 9, -82, 90, 9]], dtype=torch.int8)
Weights After:
tensor([[-126, -68, 46, -70, 126, 9, 22, -75],
[ -10, 100, 96, 48, -68, 126, -20, 15],
[ 47, 102, -14, 27, -127, 76, -31, 25],
[ 40, -55, 62, 127, -10, 6, -34, -66]], dtype=torch.int8)
Average quantization error:- tensor(0.0039, dtype=torch.bfloat16)
Replace PyTorch layers with Quantized Layers
- For language models, it is better not to quantize the last layer (the LM head).
import torch
import torch.nn as nn
import torch.nn.functional as F

def replace_linear_with_target(module,
                               target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
           any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias

            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            setattr(module, name, new_module)
            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target(
                child, target_class, module_name_to_exclude)

def replace_linear_with_target_and_quantize(module,
                                            target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
           any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            setattr(module, name, new_module)  # current module is replaced by new_module
            getattr(module, name).quantize(old_weight)

            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target_and_quantize(child,
                                                    target_class, module_name_to_exclude)
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(1, 1)
        # Try with bias
        self.linear_1 = nn.Linear(1, 1)
        # Try without bias
        self.linear_2 = nn.Linear(1, 1, bias=False)
        # Lm prediction head
        self.lm_head = nn.Linear(1, 1, bias=False)

model_1 = DummyModel()
model_2 = DummyModel()

replace_linear_with_target_and_quantize(model_1, W8A16LinearLayer, ["lm_head"])
print("model_1", model_1)

replace_linear_with_target_and_quantize(model_2, W8A16LinearLayer, [])
print("model_2", model_2)
model_1 DummyModel(
(emb): Embedding(1, 1)
(linear_1): W8A16LinearLayer()
(linear_2): W8A16LinearLayer()
(lm_head): Linear(in_features=1, out_features=1, bias=False)
)
model_2 DummyModel(
(emb): Embedding(1, 1)
(linear_1): W8A16LinearLayer()
(linear_2): W8A16LinearLayer()
(lm_head): W8A16LinearLayer()
)
Quantize any Open Source PyTorch Model
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "Salesforce/codegen-350M-mono"

model = AutoModelForCausalLM.from_pretrained(model_id,
                                             torch_dtype=torch.bfloat16,
                                             low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("def hello_world():", max_new_tokens=20, do_sample=False)[0]["generated_text"])

print("Model before:\n\n", model)

replace_linear_with_target_and_quantize(model,
                                        W8A16LinearLayer, ["lm_head"])

print("Model after:\n\n", model)

print(pipe("def hello_world():", max_new_tokens=20,
           do_sample=False)[0]["generated_text"])
Some weights of the model checkpoint at Salesforce/codegen-350M-mono were not used when initializing CodeGenForCausalLM: ['transformer.h.0.attn.causal_mask', 'transformer.h.1.attn.causal_mask', 'transformer.h.10.attn.causal_mask', 'transformer.h.11.attn.causal_mask', 'transformer.h.12.attn.causal_mask', 'transformer.h.13.attn.causal_mask', 'transformer.h.14.attn.causal_mask', 'transformer.h.15.attn.causal_mask', 'transformer.h.16.attn.causal_mask', 'transformer.h.17.attn.causal_mask', 'transformer.h.18.attn.causal_mask', 'transformer.h.19.attn.causal_mask', 'transformer.h.2.attn.causal_mask', 'transformer.h.3.attn.causal_mask', 'transformer.h.4.attn.causal_mask', 'transformer.h.5.attn.causal_mask', 'transformer.h.6.attn.causal_mask', 'transformer.h.7.attn.causal_mask', 'transformer.h.8.attn.causal_mask', 'transformer.h.9.attn.causal_mask']
- This IS expected if you are initializing CodeGenForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CodeGenForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
def hello_world():
print("Hello World")
hello_world()
# 파
Model before:
CodeGenForCausalLM(
(transformer): CodeGenModel(
(wte): Embedding(51200, 1024)
(drop): Dropout(p=0.0, inplace=False)
(h): ModuleList(
(0-19): 20 x CodeGenBlock(
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(attn): CodeGenAttention(
(attn_dropout): Dropout(p=0.0, inplace=False)
(resid_dropout): Dropout(p=0.0, inplace=False)
(qkv_proj): Linear(in_features=1024, out_features=3072, bias=False)
(out_proj): Linear(in_features=1024, out_features=1024, bias=False)
)
(mlp): CodeGenMLP(
(fc_in): Linear(in_features=1024, out_features=4096, bias=True)
(fc_out): Linear(in_features=4096, out_features=1024, bias=True)
(act): NewGELUActivation()
(dropout): Dropout(p=0.0, inplace=False)
)
)
)
(ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=1024, out_features=51200, bias=True)
)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Model after:
CodeGenForCausalLM(
(transformer): CodeGenModel(
(wte): Embedding(51200, 1024)
(drop): Dropout(p=0.0, inplace=False)
(h): ModuleList(
(0-19): 20 x CodeGenBlock(
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(attn): CodeGenAttention(
(attn_dropout): Dropout(p=0.0, inplace=False)
(resid_dropout): Dropout(p=0.0, inplace=False)
(qkv_proj): W8A16LinearLayer()
(out_proj): W8A16LinearLayer()
)
(mlp): CodeGenMLP(
(fc_in): W8A16LinearLayer()
(fc_out): W8A16LinearLayer()
(act): NewGELUActivation()
(dropout): Dropout(p=0.0, inplace=False)
)
)
)
(ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=1024, out_features=51200, bias=True)
)
def hello_world():
print("Hello World")
# hello_world()
# def hello_
- The code snippet above modifies the model in place.
- Also, don't quantize the lm_head, otherwise the model will not give the desired results.
- Rounding errors can add up as you generate more and more tokens, until eventually they become large enough to hurt the model's performance.
Load your Quantized Weights from HuggingFace Hub
- The idea is to quantize the weights on a bigger instance and push them to the Hugging Face Hub, so that we don't have to load and quantize the model again and again.
- Then use PyTorch's meta device to load only the skeleton of the model instead of the whole model.
- Replace the original layers with the quantized layers.
- Load the quantized weights from the Hugging Face Hub.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

replace_linear_with_target_and_quantize(model,
                                        W8A16LinearLayer,
                                        ["lm_head"])
model

quantized_state_dict = model.state_dict()
torch.save(quantized_state_dict, r"C:\wsl\random\models\quantized_state_dict.pth")
How to upload on HF
from huggingface_hub import HfApi, create_repo

YOUR_HF_USERNAME = ""
your_repo_id = f"{YOUR_HF_USERNAME}/opt-125m-quantized-dlai"

api = HfApi()
create_repo(your_repo_id)

api.upload_file(
    path_or_fileobj="quantized_state_dict.pth",
    path_in_repo="quantized_state_dict.pth",
    repo_id=your_repo_id
)
import torch
from transformers import OPTForCausalLM, AutoTokenizer, AutoConfig

model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id)

with torch.device("meta"):
    model = OPTForCausalLM(config)

tokenizer = AutoTokenizer.from_pretrained(model_id)

for param in model.parameters():
    print(param)
    break
Parameter containing:
tensor(..., device='meta', size=(50272, 768), requires_grad=True)
model
OPTForCausalLM(
(model): OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(50272, 768, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(layers): ModuleList(
(0-11): 12 x OPTDecoderLayer(
(self_attn): OPTSdpaAttention(
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(in_features=768, out_features=768, bias=True)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
(activation_fn): ReLU()
(self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(lm_head): Linear(in_features=768, out_features=50272, bias=False)
)
"lm_head"])
replace_linear_with_target(model, W8A16LinearLayer, [ model
OPTForCausalLM(
(model): OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(50272, 768, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(layers): ModuleList(
(0-11): 12 x OPTDecoderLayer(
(self_attn): OPTSdpaAttention(
(k_proj): W8A16LinearLayer()
(v_proj): W8A16LinearLayer()
(q_proj): W8A16LinearLayer()
(out_proj): W8A16LinearLayer()
)
(activation_fn): ReLU()
(self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(fc1): W8A16LinearLayer()
(fc2): W8A16LinearLayer()
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(lm_head): Linear(in_features=768, out_features=50272, bias=False)
)
If loading from HF
from huggingface_hub import hf_hub_download

state_dict_cache_path = hf_hub_download(
    "ybelkada/opt-125m-quantized-dlai",
    "quantized_state_dict.pth"
)

state_dict = torch.load(r"C:\wsl\random\models\quantized_state_dict.pth")
model.load_state_dict(state_dict, strict=True, assign=True)
C:\Users\rites\AppData\Local\Temp\ipykernel_32032\1779271493.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(r"C:\wsl\random\models\quantized_state_dict.pth")
<All keys matched successfully>
from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Hello today I am", max_new_tokens=40)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Hello today I am giving a course about", max_new_tokens=10)
Device set to use cuda:0
Device set to use cuda:0
[{'generated_text': 'Hello today I am giving a course about the history of the world and the history of the'}]
Weights Packing
- Weight packing is important for storing quantized weights, because torch.int4 is not available as of today, so 2/4-bit values still have to be stored and loaded as int8.
- This is not ideal because:
  - The tensor occupies 8 bits per value, which adds considerable overhead for large models.
  - There would be no point in quantizing to 2/4 bits if we still spend 8 bits per value.
- So we need to pack values.
- Consider a tensor with 4 values, each with 2-bit precision (possible values 0, 1, 2, 3) but stored in 8 bits:
  - tensor = torch.tensor([1, 0, 3, 2], dtype=torch.uint8)
  - 1: 00000001
  - 0: 00000000
  - 3: 00000011
  - 2: 00000010
- We can pack all these values into a single 8-bit value, 177:
  - 177: 10110001
- Advantages:
  - It reflects the true memory footprint of the quantized weights.
- Disadvantages:
  - The unpacked tensor's shape needs to be a multiple of 8 // nbits.
  - The values need to be unpacked before performing an operation.
Packing 2-bit Weights
import torch
# Example Tensor: [1, 0, 3, 2]
# 1 0 3 2 - 01 00 11 10
# Starting point of packed int8 Tensor
# [0000 0000]
##### First Iteration Start:
# packed int8 Tensor State: [0000 0000]
# 1 = 0000 0001
# 0000 0001
# No left shifts in the First Iteration
# After bit-wise OR operation between 0000 0000 and 0000 0001:
# packed int8 Tensor State: 0000 0001
##### First Iteration End
##### Second Iteration Start:
# packed int8 Tensor State: [0000 0001]
# 0 = 0000 0000
# 0000 0000
# 2 left shifts:
# [0000 0000] (1 shift)-> 0000 0000 (2 shift)-> 0000 0000
# After bit-wise OR operation between 0000 0001 and 0000 0000:
# packed int8 Tensor State: 0000 0001
##### Second Iteration End
##### Third Iteration Start:
# packed int8 Tensor State: [0000 0001]
# 3 = 0000 0011
# 0000 0011
# 4 left shifts:
# [0000 0011] (1 shift)-> 0000 0110 (2 shift)-> 0000 1100
# 0000 1100 (3 shift)-> 0001 1000 (4 shift)-> 0011 0000
# After bit-wise OR operation between 0000 0001 and 0011 0000:
# packed int8 Tensor State: 0011 0001
##### Third Iteration End
##### Fourth Iteration Start:
# packed int8 Tensor State: [0011 0001]
# 2 = 0000 0010
# 0000 0010
# 6 left shifts:
# [0000 0010] (1 shift)-> 0000 0100 (2 shift)-> 0000 1000
# 0000 1000 (3 shift)-> 0001 0000 (4 shift)-> 0010 0000
# 0010 0000 (5 shift)-> 0100 0000 (6 shift)-> 1000 0000
# After bit-wise OR operation between 0011 0001 and 1000 0000:
# packed int8 Tensor State: 1011 0001
##### Fourth Iteration End
# Final packed int8 Tensor State: [1011 0001]
def pack_weights(uint8tensor, bits):
    if uint8tensor.shape[0] * bits % 8 != 0:
        raise ValueError(f"The input shape needs to be a multiple of {8 / bits} - got {uint8tensor.shape[0]}")

    num_values = uint8tensor.shape[0] * bits // 8

    num_steps = 8 // bits

    unpacked_idx = 0

    packed_tensor = torch.zeros((num_values), dtype=torch.uint8)

    # 1 0 3 2 - 01 00 11 10
    # [0000 0000] -> 0000 0001
    # 0000 0001
    # 0000 0000 - 0000 0000
    # 0000 0011 - 0011 0000 - 0011 0001
    # 1011 0001
    for i in range(num_values):
        for j in range(num_steps):
            packed_tensor[i] |= uint8tensor[unpacked_idx] << (bits * j)
            unpacked_idx += 1
    return packed_tensor

unpacked_tensor = torch.tensor([1, 0, 3, 2], dtype=torch.uint8)
pack_weights(unpacked_tensor, 2)

unpacked_tensor = torch.tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)
pack_weights(unpacked_tensor, 2)
tensor([177, 255], dtype=torch.uint8)
Unpacking 2-Bit Weights
import torch
# Example Tensor: [10110001]
# Which was Originally: 1 0 3 2 - 01 00 11 10
# Starting point of unpacked Tensor
# [00000000 00000000 00000000 00000000]
##### First Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 01 from [101100 01]
# No right shifts in the First Iteration
# After bit-wise OR operation between 00000000 and 10110001:
# [10110001 00000000 00000000 00000000]
# unpacked Tensor state: [10110001 00000000 00000000 00000000]
##### First Iteration End
##### Second Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 00 from [1011 00 01]
# 2 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# After bit-wise OR operation between 00000000 and 00101100:
# [10110001 00101100 00000000 00000000]
# unpacked Tensor state: [10110001 00101100 00000000 00000000]
##### Second Iteration End
##### Third Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 11 from [10 11 0001]
# 4 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011
# After bit-wise OR operation between 00000000 and 00001011:
# [10110001 00101100 00001011 00000000]
# unpacked Tensor state: [10110001 00101100 00001011 00000000]
##### Third Iteration End
##### Fourth Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 10 from [10 110001]
# 6 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011
# 00001011 (5 shift)-> 00000101 (6 shift)-> 00000010
# After bit-wise OR operation between 00000000 and 00000010:
# [10110001 00101100 00001011 00000010]
# unpacked Tensor state: [10110001 00101100 00001011 00000010]
##### Fourth Iteration End
# Last step: Perform masking (bit-wise AND operation)
# Mask: 00000011
# Bit-wise AND operation between
# unpacked Tensor and 00000011
# [10110001 00101100 00001011 00000010] <- unpacked tensor
# [00000011 00000011 00000011 00000011] <- Mask
# [00000001 00000000 00000011 00000010] <- Result
# Final
# unpacked Tensor state: [00000001 00000000 00000011 00000010]
def unpack_weights(uint8tensor, bits):
    num_values = uint8tensor.shape[0] * 8 // bits

    num_steps = 8 // bits

    unpacked_tensor = torch.zeros((num_values), dtype=torch.uint8)

    unpacked_idx = 0

    # 1 0 3 2 - 01 00 11 10
    # [00000000 00000000 00000000 00000000]
    # [10110001 00101100 00001011 00000010]
    # [00000001 00000000 00000011 00000010]
    # 10110001
    # 00000011
    # 00000001

    # 1: [10110001]
    # 2: [00101100]
    # 3: [00001011]

    mask = 2 ** bits - 1

    for i in range(uint8tensor.shape[0]):
        for j in range(num_steps):
            unpacked_tensor[unpacked_idx] |= uint8tensor[i] >> (bits * j)
            unpacked_idx += 1

    unpacked_tensor &= mask
    return unpacked_tensor

packed_tensor = torch.tensor([177, 255], dtype=torch.uint8)
# Answer should be: torch.tensor([1, 0, 3, 2, 3, 3, 3, 3])
unpack_weights(packed_tensor, 2)
tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)
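A round-trip sanity check of my own: packing and then unpacking should recover the original 2-bit values.

original = torch.tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)
packed = pack_weights(original, 2)
restored = unpack_weights(packed, 2)
print(torch.equal(original, restored))  # True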
Beyond Linear Quantization
Emergent features at scale: characteristics or features that only appear once the model is large enough.
At scale, the magnitude of some hidden states starts to get very large, which makes classic linear quantization schemes break down and fail on these models.
So how do we deal with outlier features in LLMs?
Outlier features simply means hidden states with a large magnitude.
There are some interesting papers addressing this, such as LLM.int8(), SmoothQuant and AWQ.
- LLM.int8() separates the matmul into two parts (see the sketch after this list):
  - For non-outliers (smaller values): perform the matmul in int8, then dequantize the result.
  - For outliers (larger values): perform the matmul the classical way (in the dtype of the hidden states, usually half precision), then combine the two results.
- SmoothQuant
  - Applies a W8A8 scheme (quantizes both weights and activations).
  - Given an input, it determines a smoothing factor and uses it during quantization.
  - It migrates the scale variance from activations to weights to reduce the quantization difficulty of the activations.
  - The smoothed activations and the adjusted weights are then both easy to quantize.
- AWQ
  - Uses a calibration dataset to find out which weights could be responsible for generating outlier features; these are called salient weights.
  - That information is used to scale the weights before quantization, and the same scale is applied during inference to rescale the input.
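To make the outlier-decomposition idea concrete, here is a toy sketch in plain PyTorch (my own simplification, not the bitsandbytes implementation: real LLM.int8() uses vector-wise scales and a real int8 kernel, while here the int8 path is only simulated with per-tensor scales):

import torch

def decomposed_matmul(x, w, threshold=6.0):
    # Columns of x with a large magnitude are treated as outlier features
    outlier_cols = x.abs().max(dim=0).values > threshold

    # Outlier part: keep it in the original (higher) precision
    out_hp = x[:, outlier_cols] @ w[outlier_cols, :]

    # Non-outlier part: symmetric 8-bit quantization of both operands,
    # integer-style matmul, then dequantization with the product of the scales
    x_sub, w_sub = x[:, ~outlier_cols], w[~outlier_cols, :]
    s_x = x_sub.abs().max() / 127
    s_w = w_sub.abs().max() / 127
    x_q = torch.round(x_sub / s_x).clamp(-128, 127)
    w_q = torch.round(w_sub / s_w).clamp(-128, 127)
    out_q = (x_q @ w_q) * (s_x * s_w)

    return out_hp + out_q

x = torch.randn(4, 16)
x[:, 3] += 20                      # inject one outlier feature
w = torch.randn(16, 8)
print(decomposed_matmul(x, w).shape)  # torch.Size([4, 8])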
Recent SOTA quantization methods:
- LLM.int8()
- GPTQ
- SmoothQuant
- QLoRA
- AWQ
- QuIP#
- HQQ
- AQLM
- ………..
Challenges of Quantization
- Retraining (Quantization Aware Training) [less explored]
- Limited Hardware support
- Calibration dataset needed
- packing/unpacking
Some Other resources
- MIT Han Lab
- Huggingface transformers quantization docs/blogposts
- llama.cpp discussions
- Reddit LocalLlama